/* Project: Gipsa-lab plugins for OpenVibe
 * AUTHORS AND CONTRIBUTORS: Andreev A., Barachant A., Congedo M., Ionescu,Gelu

 * This file is part of "Gipsa-lab plugins for OpenVibe".
 * You can redistribute it and/or modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version.
 * This file is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty
 * of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
 * You should have received a copy of the GNU General Public License along with "Gipsa-lab plugins for OpenVibe". If not, see http://www.gnu.org/licenses/.*/
 
#include "ovpCBoxAlgorithmTrainMDM.h"

#include <cmath>        // std::sqrt
#include <iostream>
#include <sstream>
#include <system/Memory.h>

#include "ovpRiemannHelper.h"
#include "ovpCBoxAlgorithmProcessMDM.h"

#include <boost/date_time/posix_time/posix_time.hpp>
#include <numeric>      // std::accumulate

using namespace OpenViBE;
using namespace OpenViBE::Kernel;
using namespace OpenViBE::Plugins;

using namespace OpenViBEPlugins;
using namespace OpenViBEPlugins::SignalProcessing;

// Sets up the box:
//  - a stimulation decoder for input 0 and one signal decoder per signal input (1..N-1),
//  - the stimulation encoder used on output 0,
//  - the settings (0: parameter file path, 1: train stimulation, 2: IsP300),
//  - the parameter file (opened for writing now, written after training),
//  - one training class per signal input.
// Configuration problems are logged as errors only; initialization still returns true.
boolean CBoxAlgorithmTrainMDM::initialize(void)
{
	//> init INPUT stimulation
	m_pStimulationDecoderTrigger=&this->getAlgorithmManager().getAlgorithm(this->getAlgorithmManager().createAlgorithm(OVP_GD_ClassId_Algorithm_StimulationStreamDecoder));
	m_pStimulationDecoderTrigger->initialize();
	ip_pMemoryBufferToDecodeTrigger.initialize(m_pStimulationDecoderTrigger->getInputParameter(OVP_GD_Algorithm_StimulationStreamDecoder_InputParameterId_MemoryBufferToDecode));
	op_pStimulationSetTrigger.initialize(m_pStimulationDecoderTrigger->getOutputParameter(OVP_GD_Algorithm_StimulationStreamDecoder_OutputParameterId_StimulationSet));

	IBox& l_rStaticBoxContext=this->getStaticBoxContext();

	//all signal INPUT channels; input 0 is the stimulation input
	for(uint32 i=1; i<l_rStaticBoxContext.getInputCount(); i++)
	{
		IAlgorithmProxy* l_pStreamDecoder=&this->getAlgorithmManager().getAlgorithm(this->getAlgorithmManager().createAlgorithm(OVP_GD_ClassId_Algorithm_SignalStreamDecoder));
		l_pStreamDecoder->initialize();
		m_vStreamDecoder.push_back(l_pStreamDecoder);
	}

	//output stimulation 1
	m_oStimulationEncoder.initialize(*this);

	//UI parameters:
	CString l_sSettingValue;

	//Training file location
	l_rStaticBoxContext.getSettingValue(0, l_sSettingValue);
	std::string l_sFileNamePath=(std::string)this->getConfigurationManager().expand(l_sSettingValue);

	//Train Stimulation
	m_ui64STrainStimulationIdentifier=FSettingValueAutoCast(*this->getBoxAlgorithmContext(), 1);

	//IsP300
	l_rStaticBoxContext.getSettingValue(2, l_sSettingValue);
	m_IsP300=(OpenViBE::boolean)this->getConfigurationManager().expandAsBoolean(l_sSettingValue);

	//Start processing file section
	if (l_sFileNamePath.empty())
	{
		this->getLogManager() << LogLevel_Error << "Parameter filename is empty!\n";
		// deliberately not fatal: the box keeps running, the error is only logged
	}

	this->getLogManager() << LogLevel_Info << "Parameter file:" << l_sFileNamePath.c_str() << "\n";

	m_ParamFile.open(l_sFileNamePath.c_str());

	// A failed open() sets failbit, not badbit, so bad() would never detect it.
	if (!m_ParamFile.is_open())
	{
		this->getLogManager() << LogLevel_Error << "Could not create parameter file:" << l_sFileNamePath.c_str() << "\n";
		// deliberately not fatal: the box keeps running, the error is only logged
	}
	//End processing file section

	m_bStartTrain = false;

	m_vTrainingClasses.clear();

	// one training class per signal input (input 0 is stimulations)
	for(uint32 c=1; c<l_rStaticBoxContext.getInputCount(); c++)
	{
		m_vTrainingClasses.push_back(new RClass());
	}

	// P300 needs exactly: stimulation input + non-target input + target input
	if (m_IsP300 && l_rStaticBoxContext.getInputCount() != 3)
	{
		this->getLogManager() << LogLevel_Error << "For P300 you need exactly 1 stimulation and 2 signal input channels (target and non-target)." << "\n";
	}

	return true;
}

// Releases everything created in initialize(): the signal decoders, the output
// stimulation encoder, the input stimulation decoder and its parameter handlers,
// and the training classes. Also closes the parameter file if it is still open
// (i.e. training never completed and saveFile() did not close it).
boolean CBoxAlgorithmTrainMDM::uninitialize(void)
{
	//uninit all signal input decoders
	for(uint32 i=0; i<m_vStreamDecoder.size(); i++)
	{
		IAlgorithmProxy* l_pStreamDecoder=m_vStreamDecoder[i];
		l_pStreamDecoder->uninitialize();
		this->getAlgorithmManager().releaseAlgorithm(*l_pStreamDecoder);
	}
	m_vStreamDecoder.clear();

	//uninit output stimulation - Train completed
	m_oStimulationEncoder.uninitialize();

	//uninit input stimulation
	m_pStimulationDecoderTrigger->uninitialize();
	ip_pMemoryBufferToDecodeTrigger.uninitialize();
	op_pStimulationSetTrigger.uninitialize();
	this->getAlgorithmManager().releaseAlgorithm(*m_pStimulationDecoderTrigger);

	// the classes were allocated with new in initialize(); deleting them also frees their buffers
	for(uint32 i=0; i<m_vTrainingClasses.size(); i++)
	{
		delete m_vTrainingClasses[i];
	}
	m_vTrainingClasses.clear();

	// saveFile() closes the file after a successful training; otherwise it is still open here
	if (m_ParamFile.is_open())
	{
		m_ParamFile.close();
	}

	return true;
}

// Kernel notification that input ui32InputIndex received a chunk.
// Chunks are consumed in process(), so every arrival, on any input, just
// schedules a process() call; the index itself is not needed here.
boolean CBoxAlgorithmTrainMDM::processInput(uint32 ui32InputIndex)
{
	this->getBoxAlgorithmContext()
		->markAlgorithmAsReadyToProcess();

	return true;
}

// Main processing step.
// 1. Decodes every pending chunk of the stimulation input (0). It forwards a header
//    on output 0 and arms training (m_bStartTrain) when the configured train
//    stimulation is received.
// 2. Once armed, it decodes every pending chunk of each signal input c (1..N-1) into
//    the buffered signal of training class c-1. It then trains, auto-validates,
//    saves the parameter file and sends OVTK_StimulationId_TrainCompleted on output 0.
// Returns false if a class received no data or if training fails.
boolean CBoxAlgorithmTrainMDM::process(void)
{
	// Timing of the stimulation chunk that carried the train stimulation; used to date the
	// "train completed" output. Zero-initialised so they are never read uninitialised.
	uint64 l_ui64TrainDate = 0;
	uint64 l_ui64TrainChunkStartTime = 0;
	uint64 l_ui64TrainChunkEndTime = 0;
	IBox& l_rStaticBoxContext=this->getStaticBoxContext();
	IBoxIO* l_rDynamicBoxContext=getBoxAlgorithmContext()->getDynamicBoxContext();

	//check stimulation channel
	for(uint32 j=0; j<l_rDynamicBoxContext->getInputChunkCount(0); j++)
	{
		ip_pMemoryBufferToDecodeTrigger=l_rDynamicBoxContext->getInputChunk(0, j);
		m_pStimulationDecoderTrigger->process();

		//HEADER
		if(m_pStimulationDecoderTrigger->isOutputTriggerActive(OVP_GD_Algorithm_StimulationStreamDecoder_OutputTriggerId_ReceivedHeader))
		{
			m_oStimulationEncoder.encodeHeader(0);
			l_rDynamicBoxContext->markOutputAsReadyToSend(0,l_rDynamicBoxContext->getInputChunkStartTime(0, j),l_rDynamicBoxContext->getInputChunkEndTime(0, j));
		}

		//BUFFER
		if(m_pStimulationDecoderTrigger->isOutputTriggerActive(OVP_GD_Algorithm_StimulationStreamDecoder_OutputTriggerId_ReceivedBuffer))
		{
			for(uint32 k=0; k<op_pStimulationSetTrigger->getStimulationCount(); k++)
			{
				// Check for Train Stimulation
				if (op_pStimulationSetTrigger->getStimulationIdentifier(k)==m_ui64STrainStimulationIdentifier)
				{
					m_bStartTrain = true;
					this->getLogManager() << LogLevel_Warning << "Start train stimulation detected.\n";
				}
			}

			if(m_bStartTrain)
			{
				l_ui64TrainChunkStartTime = l_rDynamicBoxContext->getInputChunkStartTime(0, j);
				l_ui64TrainChunkEndTime = l_rDynamicBoxContext->getInputChunkEndTime(0, j);
				l_ui64TrainDate = l_ui64TrainChunkEndTime;
			}
		}

		l_rDynamicBoxContext->markInputAsDeprecated(0,j);
	}

	if (m_bStartTrain) //the user has finished the training phase, now we start calculating on the buffered data (which is also called training)
	{
		#if defined(HAS_CONCURRENCY)
		this->getLogManager() << LogLevel_Info << "Concurrent execution is enabled: "
		#if defined(HAS_TBB)
		   << "TBB\n";
		#else
		   << "PPT\n";
		#endif
		;
		#endif
		this->getLogManager() << LogLevel_Info << "Starting actual calculations.\n";
		boost::posix_time::ptime t1 = boost::posix_time::microsec_clock::local_time();

		//1. We cache the data: signal input c feeds training class c-1 (input 0 is stimulations)
		for(uint32 c=1; c<l_rStaticBoxContext.getInputCount(); c++)
		{
			IAlgorithmProxy* l_pStreamDecoder=m_vStreamDecoder[c-1];
			// Handlers are loop-invariant per decoder: they read the parameter value on access
			TParameterHandler<const IMemoryBuffer*> ip_pMemoryBuffer(l_pStreamDecoder->getInputParameter(OVP_GD_Algorithm_SignalStreamDecoder_InputParameterId_MemoryBufferToDecode));
			TParameterHandler<IMatrix*> op_pMatrix(l_pStreamDecoder->getOutputParameter(OVP_GD_Algorithm_SignalStreamDecoder_OutputParameterId_Matrix));

			for(uint32 i=0; i<l_rDynamicBoxContext->getInputChunkCount(c); i++)
			{
				ip_pMemoryBuffer=l_rDynamicBoxContext->getInputChunk(c, i);
				l_pStreamDecoder->process();

				//HEADER: nothing to do, each buffer is converted with its own dimensions

				//BUFFER
				if(l_pStreamDecoder->isOutputTriggerActive(OVP_GD_Algorithm_SignalStreamDecoder_OutputTriggerId_ReceivedBuffer))
				{
					m_vTrainingClasses[c-1]->vBufferedSignal.push_back(convert(*op_pMatrix));
				}

				// consume the chunk just read from signal input c
				l_rDynamicBoxContext->markInputAsDeprecated(c, i);
			}
		}

		//quick data check
		for (vector_type<RClass*>::iterator signalChannelsIt=m_vTrainingClasses.begin(); signalChannelsIt!=m_vTrainingClasses.end(); ++signalChannelsIt)
		{
			if ((*signalChannelsIt)->vBufferedSignal.empty())
			{
				this->getLogManager() << LogLevel_Error << "Input channel contains no data!\n";
				return false;
			}
		}

		if (!train()) return false;

		//Print time
		boost::posix_time::ptime t2 = boost::posix_time::microsec_clock::local_time();
		boost::posix_time::time_duration diff = t2 - t1;
		this->getLogManager() << LogLevel_Info << "Training completed! Calculation time: " << (double)diff.total_milliseconds() / double(1000) << " seconds. All OK! Saving file ...\n";

		autoValidation();

		saveDistances();

		//Save output fo file
		saveFile();

		//Send Train completed to the next box
		m_oStimulationEncoder.getInputStimulationSet()->appendStimulation(OVTK_StimulationId_TrainCompleted, l_ui64TrainDate, 0);
		m_oStimulationEncoder.encodeBuffer(0);
		l_rDynamicBoxContext->markOutputAsReadyToSend(0, l_ui64TrainChunkStartTime, l_ui64TrainChunkEndTime);

		m_bStartTrain = false;
	}

	return true;
}

// Computes the MDM model from the buffered epochs of every class.
// - P300 mode: m_P1 is the average of the epochs of class index 1 (the "target" input).
//   Each epoch X is stacked under it ([P1; X]) before its covariance is taken.
// - Each buffered epoch of class i yields one covariance matrix, stored in
//   vCovarianceMatrices. These are recomputed from scratch, so a retrain does not
//   duplicate earlier matrices.
// - ResultMean of each class is Riemann::mean over its covariance matrices.
// Returns false if P1 could not be built (P300 mode) or a class has no covariance matrix.
boolean CBoxAlgorithmTrainMDM::train()
{
	//1. P300 only: sum the target epochs (for a mean)
	if (m_IsP300 && m_vTrainingClasses.size() == 2)
	{
		bool firstRun = true;
		for(std::vector<itpp::mat>::iterator P300epochsIt=m_vTrainingClasses[1]->vBufferedSignal.begin();P300epochsIt!=m_vTrainingClasses[1]->vBufferedSignal.end();++P300epochsIt)
		{
			const itpp::mat& X = *P300epochsIt;
			if (firstRun)
			{
				m_P1 = itpp::zeros(X.rows(),X.cols());
				firstRun = false;
			}

			m_P1 += X;
		}
	}

	if (m_IsP300) //calculate mean in P1
	{
		if (m_P1.rows()==0 || m_P1.cols()==0)
		{
			this->getLogManager() << LogLevel_Error << "Training failed. P1 is bad!\n";
			return false;
		}

		//second channel is "target" for P300
		m_P1 /= double(m_vTrainingClasses[1]->vBufferedSignal.size());
	}

	// start from an empty list so that a second training does not duplicate matrices
	for(uint32 i=0;i<m_vTrainingClasses.size();i++)
	{
		m_vTrainingClasses[i]->vCovarianceMatrices.clear();
	}

	//2. Calculate covariance matrices for each channel (each task only touches class i)
	#if defined(HAS_CONCURRENCY)
	parallel_for(0, int32(m_vTrainingClasses.size()), [&](size_t i)
	#else
	for(uint32 i=0;i<m_vTrainingClasses.size();i++)
	#endif
	{
		for (std::vector<itpp::mat>::iterator it=m_vTrainingClasses[i]->vBufferedSignal.begin(); it!=m_vTrainingClasses[i]->vBufferedSignal.end(); ++it)
		{
			const itpp::mat& X1 = *it;
			itpp::mat P;

			if (m_IsP300)
			{
				//concatenate
				//[P1]
				//[X1]
				const itpp::mat XC = itpp::concat_vertical(m_P1,X1);
				P = itpp::cov(XC.transpose(),false);
			}
			else
			{
				P = itpp::cov(X1.transpose(),false);
			}

			m_vTrainingClasses[i]->vCovarianceMatrices.push_back(P);
		}
	}
	#if defined(HAS_CONCURRENCY)
	);
	#endif

	//3. Final calculation - the mean (the barycenters for each class)

	//make quick check
	for(uint32 i=0;i<m_vTrainingClasses.size();i++)
	{
		if (m_vTrainingClasses[i]->vCovarianceMatrices.empty())
		{
			this->getLogManager() << LogLevel_Error << "Training failed! Matrices are bad!\n";
			return false;
		}
	}

	//calculate and store result matrices
	#if defined(HAS_CONCURRENCY)
	parallel_for(0, int32(m_vTrainingClasses.size()), [&](size_t i)
	#else
	for(uint32 i=0;i<m_vTrainingClasses.size();i++)
	#endif
	{
		//mean covariance over all covariance matrices of class i
		m_vTrainingClasses[i]->ResultMean = Riemann::mean(m_vTrainingClasses[i]->vCovarianceMatrices);
	}
	#if defined(HAS_CONCURRENCY)
	);
	#endif

	return true;
}

// Converts a 2-D OpenViBE matrix into an IT++ matrix of the same shape
// (dimension 0 -> rows, dimension 1 -> columns).
// The raw buffer is copied into an IT++ matrix sized (dim1 x dim0). IT++ is
// column-major, so that intermediate holds the transpose of the input, which is
// flipped back before returning.
itpp::mat CBoxAlgorithmTrainMDM::convert(const OpenViBE::IMatrix& rMatrix)
{
	const uint32 l_ui32Dim0 = rMatrix.getDimensionSize(0);
	const uint32 l_ui32Dim1 = rMatrix.getDimensionSize(1);

	itpp::mat l_oTransposed(l_ui32Dim1, l_ui32Dim0);
	System::Memory::copy(
		l_oTransposed._data(),
		rMatrix.getBuffer(),
		rMatrix.getBufferElementCount()*sizeof(float64));

	return l_oTransposed.transpose();
}

void CBoxAlgorithmTrainMDM::saveFile()
{
	if (!m_ParamFile.bad())
	{
		uint32 count = (m_IsP300) ? m_vTrainingClasses.size() +1 : m_vTrainingClasses.size();

		m_ParamFile << "count: " << count << std::endl;

		uint32 i=0;
		for (uint32 i=0;i<m_vTrainingClasses.size();i++)
		{
			m_ParamFile << "C" << i << std::endl << m_vTrainingClasses[i]->ResultMean << std::endl;
		}

		if (m_IsP300)
		{
			m_ParamFile << "P1" << std::endl << m_P1 << std::endl;
		}

		m_ParamFile << std::endl;

		for (uint32 i=0;i<m_vTrainingClasses.size();i++)
		{
			m_ParamFile << "C" << i << "_mean_std=" << m_vTrainingClasses[i]->DistanceMean << ";" << m_vTrainingClasses[i]->DistanceSTD << "\n";
		}

		m_ParamFile.flush();
		m_ParamFile.close();

		this->getLogManager() << LogLevel_Info << "Parameters file saved!\n";
	}
	else
	{
		this->getLogManager() << LogLevel_Warning << "Result calculated successfully, but could not be saved to a file!\n";
		//potentially print to console
	}
}

void CBoxAlgorithmTrainMDM::autoValidation()
{
	this->getLogManager() << LogLevel_Info << "Accuracy in auto-validation:\n";

	vector_type<int> l_vAccuracy=vector_type<int>(m_vTrainingClasses.size());

	//construct parameter with the means from all classes
	vector_type<itpp::mat> l_vResultMean;
	for(int i=0;i<m_vTrainingClasses.size();i++)
	{
		l_vResultMean.push_back(m_vTrainingClasses[i]->ResultMean);
	}

	//the signal m_vCovarianceMatricesPerChannel per channel corresponds per class
	//#if defined(HAS_CONCURRENCY)
    //parallel_for(0, int(m_vTrainingClasses.size()), [&](size_t i)
	//#else
	for(int i=0;i<m_vTrainingClasses.size();i++)
	//#endif
	{
		//test per class 
		for (std::vector<itpp::mat>::iterator it=m_vTrainingClasses[i]->vCovarianceMatrices.begin(); it!=m_vTrainingClasses[i]->vCovarianceMatrices.end(); it++)
		{
			std::vector<OpenViBE::float64> l_vDistances; //the distance between point *it and all the classes l_vResultMean

			//classify
			OpenViBE::uint32 index = SignalProcessing::CBoxAlgorithmProcessMDM::ApplyMDM(*it, l_vResultMean, l_vDistances);

			for(int k=0;k<m_vTrainingClasses.size();k++)
			{
				m_vTrainingClasses[k]->v_RiemmanDistances.push_back(l_vDistances[k]);
			}
				
			if (index == i) //if the predicted class match the expected
				l_vAccuracy[i]++;
		}
	}
	//#if defined(HAS_CONCURRENCY) 
	//); 
    //#endif

	uint32 totalTrials = 0;
	uint32 totalPosistive = 0; //total classifications that match the training
	//print
	for(uint32 i=0;i<m_vTrainingClasses.size();i++)
	{
		totalTrials = totalTrials + m_vTrainingClasses[i]->vCovarianceMatrices.size();
		totalPosistive = totalPosistive + l_vAccuracy[i];
		this->getLogManager() << LogLevel_Info << "Class: C" << intToString(i).c_str() << " "<< l_vAccuracy[i] << " / " << intToString(m_vTrainingClasses[i]->vCovarianceMatrices.size()).c_str() << "\n";
	}
	this->getLogManager() << LogLevel_Info << "Total data fit: " << totalPosistive * 100 / totalTrials << "%\n";
}

void CBoxAlgorithmTrainMDM::saveDistances()
{
	for(int k=0;k<m_vTrainingClasses.size();k++)
	{
		std::vector<double> v = m_vTrainingClasses[k]->v_RiemmanDistances;
		double sum = std::accumulate(v.begin(), v.end(), 0.0);
        double mean = sum / v.size();

		m_vTrainingClasses[k]->DistanceMean = mean;
		std::cout << m_vTrainingClasses[k]->DistanceMean << " ";

	   std::vector<double> diff(v.size());
       std::transform(v.begin(), v.end(), diff.begin(),
               std::bind2nd(std::minus<double>(), mean));
       
	   double sq_sum = std::inner_product(diff.begin(), diff.end(), diff.begin(), 0.0);
       double stdev = std::sqrt(sq_sum / v.size());

	   m_vTrainingClasses[k]->DistanceSTD = stdev;
	   std::cout << m_vTrainingClasses[k]->DistanceSTD << std::endl;
	}
	
}

// Returns the decimal text form of number (e.g. -12 -> "-12"). Used to log
// integers as plain text.
std::string CBoxAlgorithmTrainMDM::intToString(int number)
{
	std::ostringstream l_oText;
	l_oText << number;
	return l_oText.str();
}
